import logging

from collections import defaultdict
from lab.reports import Table, CellFormatter
from downward.reports import PlanningReport

def toString(f):
    """Render the number ``f`` with exactly one digit after the decimal point."""
    return format(f, '.1f')

def valid(attribute):
    """Tell whether an attribute value is usable.

    A value is unusable when it is ``None`` (attribute absent) or equal to
    the sentinel ``-1``; anything else counts as valid.
    """
    if attribute is None:
        return False
    return attribute != -1

class SummaryReport(PlanningReport):
    """
    Compare a bidirectional (nbs) algorithm against a unidirectional one.

    If the experiment contains more than two algorithms, use
    ``filter_algorithm`` to select exactly two of them for the report.
    The name of one algorithm must contain the key word ``nbs`` (the
    bidirectional search); the other one (e.g. an ``astar`` configuration)
    is used as the unidirectional baseline.

    >>> from downward.experiment import FastDownwardExperiment
    >>> exp = FastDownwardExperiment()
    >>> exp.add_report(SummaryReport(
    ...     attributes=["expansions", "search_time"],
    ...     filter_algorithm=["astar_lmcut","nbs_lmcut"]))

    Output legend. "min-(med)-max" lists the smallest value, the element at
    index len // 2 of the sorted values, and the largest value:
        'Actions' --> min-(med)-max of 'ratio_fb_actions' over the nbs runs.
        'Ex: Bi < Uni' --> # of problems where nbs expands fewer states than
            the baseline (bold if > 0).
        'Ex_jump: Bi < 2*Uni' --> # of problems where the nbs 'jump_expanded'
            value is less than twice the baseline's.
        'Expansions: <algo>' --> min-(med)-max of 'expanded' per algorithm.
        'h: B > F' --> # of nbs runs where the backward h-value at the
            meeting point is at least the forward h-value.
        'Initial Goals: ' --> min-(med)-max of 'b_initial_goals' (nbs runs).
        'MP: B <-- F' --> # of nbs runs where the backward g-value at the
            meeting point is at least the forward g-value.
        'Ratio: Ex-Jumps' --> min-(med)-max of 'ratio_jump_expanded'
            (before_jump_expanded / expanded).
        'S: <algo>' --> # of solved problems per algorithm; the algorithm
            with more solved problems is bold (the baseline on ties).
    """
    def __init__(self, **kwargs):
        # All options (e.g. attributes, filter_algorithm) go to PlanningReport.
        PlanningReport.__init__(self, **kwargs)

    def _translate_mp(self, mp):
        """Turn a meeting-point string into a pair of 0/1 flags ``(h, s)``.

        ``mp`` is expected to hold comma-separated ``key: int`` pairs, e.g.
        ``"f_g: 3, b_g: 4, f_h: 2, b_h: 5"``. Returns ``h = 1`` iff
        ``b_h >= f_h`` and ``s = 1`` iff ``b_g >= f_g``. A missing meeting
        point (``None``, empty, or containing ``'-'``) or an unparsable one
        yields ``(0, 0)``.
        """
        # '-' is the default used by the caller when no meeting point exists.
        if not mp or '-' in mp:
            return (0, 0)
        keys = ('f_g', 'b_g', 'f_h', 'b_h')
        found = {}
        for part in mp.split(', '):
            for key in keys:
                prefix = key + ': '
                if prefix in part:
                    found[key] = part.split(prefix, 1)[1]
        try:
            f_g, b_g, f_h, b_h = [int(found[key]) for key in keys]
        except (KeyError, ValueError):
            # Previously this crashed with UnboundLocalError/ValueError.
            logging.warning('Cannot parse meeting point: %s', mp)
            return (0, 0)
        h = 1 if b_h >= f_h else 0
        s = 1 if b_g >= f_g else 0
        return (h, s)

    @staticmethod
    def _min_med_max(values, fmt=str):
        """Return ``"min-(med)-max"`` for a non-empty sequence of numbers.

        ``med`` is the element at index ``len // 2`` of the sorted values
        (the upper median for an even count). ``fmt`` converts each of the
        three numbers to a string.
        """
        ordered = sorted(values)
        # Floor division keeps the index an int; len / 2 is a float on
        # Python 3 and would raise TypeError when used as an index.
        middle = ordered[len(ordered) // 2]
        return fmt(ordered[0]) + "-(" + fmt(middle) + ")-" + fmt(ordered[-1])

    def _get_table(self):
        """Build the per-domain summary table described in the class legend.

        Requires ``self.algoBi`` and ``self.algoUni`` (set by
        ``_check_algorithms``).
        """
        table = Table(title="Domain", colored=True)

        # Per-domain aggregates, filled from the runs below.
        ig = defaultdict(list)             # domain -> b_initial_goals (nbs)
        ex = defaultdict(list)             # (domain, algo) -> expanded
        ex_counter = defaultdict(int)      # domain -> # problems Bi < Uni
        jump_ex_counter = defaultdict(int) # domain -> # problems Bi < 2*Uni
        h_counter = defaultdict(int)       # domain -> sum of h flags
        mp_counter = defaultdict(int)      # domain -> sum of s flags
        actions = defaultdict(list)        # domain -> ratio_fb_actions (nbs)
        ratio_ex = defaultdict(list)       # domain -> ratio_jump_expanded
        solved = defaultdict(int)          # (domain, algo) -> # solved
        formatter = CellFormatter(bold=True)

        for (domain, _problem), runs in sorted(self.problem_runs.items()):
            # -1 marks "no run of this kind seen for the problem".
            ex_uni = ex_bi = -1
            jump_ex_uni = jump_ex_bi = -1
            for run in runs:
                algo = run.get('algorithm')
                if run.get('error') == 'success':
                    solved[domain, algo] += 1
                if valid(run.get('expanded')):
                    ex[domain, algo].append(run.get('expanded'))
                if algo == self.algoBi:
                    ex_bi = run.get('expanded')
                    h, s = self._translate_mp(run.get('meeting_point', '-'))
                    h_counter[domain] += h
                    mp_counter[domain] += s
                    jump_ex_bi = run.get('jump_expanded', -1)
                    if valid(run.get('ratio_jump_expanded')):
                        ratio_ex[domain].append(run.get('ratio_jump_expanded'))
                    if valid(run.get('b_initial_goals')):
                        ig[domain].append(run.get('b_initial_goals'))
                    actions[domain].append(run.get('ratio_fb_actions', 1.0))
                else:
                    ex_uni = run.get('expanded')
                    jump_ex_uni = run.get('jump_expanded', -1)
            # Validate both operands before comparing: a missing attribute is
            # None, and None < int raises TypeError on Python 3.
            if valid(ex_bi) and valid(ex_uni) and ex_bi < ex_uni:
                ex_counter[domain] += 1
            if (valid(jump_ex_bi) and valid(jump_ex_uni)
                    and jump_ex_bi < 2 * jump_ex_uni):
                jump_ex_counter[domain] += 1

        # Write cells.
        for domain, values in ig.items():
            table.add_cell(domain, 'Initial Goals: ', self._min_med_max(values))
        for (domain, algo), values in ex.items():
            table.add_cell(domain, 'Expansions: ' + algo,
                           self._min_med_max(values))
        for domain in self.domains.keys():
            table.add_cell(domain, 'Ex: Bi < Uni', ex_counter.get(domain, 0))
            if ex_counter.get(domain, 0) > 0:
                table.cell_formatters[domain]['Ex: Bi < Uni'] = formatter
        for domain in self.domains.keys():
            table.add_cell(domain, 'Ex_jump: Bi < 2*Uni',
                           jump_ex_counter.get(domain, 0))
        for domain, values in actions.items():
            table.add_cell(domain, 'Actions',
                           self._min_med_max(values, toString))
        for domain, counter in h_counter.items():
            table.add_cell(domain, 'h: B > F', str(counter))
        for domain, counter in mp_counter.items():
            table.add_cell(domain, 'MP: B <-- F', str(counter))
        for domain in self.domains.keys():
            for algo in self.algorithms:
                table.add_cell(domain, 'S: ' + algo,
                               str(solved.get((domain, algo), 0)))
            # Bold the algorithm that solved more problems (baseline on ties).
            if (solved.get((domain, self.algoBi), 0) >
                    solved.get((domain, self.algoUni), 0)):
                solved_row = 'S: ' + self.algoBi
            else:
                solved_row = 'S: ' + self.algoUni
            table.cell_formatters[domain][solved_row] = formatter
        for domain, values in ratio_ex.items():
            table.add_cell(domain, 'Ratio: Ex-Jumps',
                           self._min_med_max(values, toString))

        return table

    def _check_algorithms(self):
        """Set ``self.algoBi`` and ``self.algoUni`` from the report algorithms.

        The algorithm whose name contains ``nbs`` becomes ``algoBi``; the
        other one becomes ``algoUni``. Logs a critical error when there are
        not exactly two algorithms or they cannot be split this way.
        """
        # NOTE(review): Lab's logging setup presumably aborts on critical
        # messages — confirm; otherwise _get_table fails later because the
        # attributes below are never set.
        if len(self.algorithms) != 2:
            logging.critical('Summary Reports need exactly two algorithms.')
            return
        bi_algos = [algo for algo in self.algorithms if 'nbs' in algo]
        uni_algos = [algo for algo in self.algorithms if 'nbs' not in algo]
        if len(bi_algos) != 1 or len(uni_algos) != 1:
            logging.critical(
                "Summary Reports need exactly one algorithm whose name "
                "contains 'nbs' and one whose name does not.")
            return
        self.algoBi = bi_algos[0]
        self.algoUni = uni_algos[0]

    def get_markup(self):
        """Return the report markup: the summary table rendered as a string."""
        self._check_algorithms()
        return str(self._get_table())


# Report properties available on self (used or usable by this report):
# domains               : (domain) --> problems
# problems              : set of (domain, problem)
# problem_runs          : (domain, problem) --> runs
# domain_algorithm_runs : (domain, algorithm) --> runs
# runs                  : (domain, problem, algo) --> run
# attributes
# algorithms            : set of algorithm names
# algorithm_info
